{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp data.core"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.torch_basics import *\n",
    "from fastai.data.load import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data core\n",
    "\n",
    "> Core functionality for gathering data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classes here provide functionality for applying a list of transforms to a set of items (`TfmdLists`, `Datasets`) or a `DataLoader` (`TfmdDL`) as well as the base class used to gather the data for model training: `DataLoaders`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TfmdDL -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_batch(x, y, samples, ctxs=None, max_n=9, **kwargs):\n",
    "    \"Show at most `max_n` decoded `samples` on `ctxs`; `x` and `y` are not used here beyond type dispatch\"\n",
    "    if ctxs is None: ctxs = Inf.nones\n",
    "    if not hasattr(samples[0], 'show'):\n",
    "        # Samples cannot show themselves: show each of their components in turn,\n",
    "        # feeding the contexts returned by one pass into the next\n",
    "        for col in range_of(samples[0]):\n",
    "            ctxs = [item.show(ctx=ctx, **kwargs) for item,ctx,_ in zip(samples.itemgot(col),ctxs,range(max_n))]\n",
    "        return ctxs\n",
    "    # Each sample is shown as a whole through its own `show` method\n",
    "    return [sample.show(ctx=ctx, **kwargs) for sample,ctx,_ in zip(samples,ctxs,range(max_n))]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`show_batch` is a type-dispatched function that is responsible for showing decoded `samples`. `x` and `y` are the input and the target in the batch to be shown, and are passed along to dispatch on their types. There is a different implementation of `show_batch` if `x` is a `TensorImage` or a `TensorText` for instance (see vision.core or text.data for more details). `ctxs` can be passed but the function is responsible for creating them if necessary. `kwargs` depend on the specific implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x, y, samples, outs, ctxs=None, max_n=9, **kwargs):\n",
    "    \"Show the components of at most `max_n` `samples`, then those of the matching `outs`, on the same `ctxs`\"\n",
    "    if ctxs is None: ctxs = Inf.nones\n",
    "    # `samples` first, then `outs`: every pass draws on the contexts returned by the previous one\n",
    "    for group in (samples, outs):\n",
    "        for col in range(len(group[0])):\n",
    "            ctxs = [item.show(ctx=ctx, **kwargs) for item,ctx,_ in zip(group.itemgot(col),ctxs,range(max_n))]\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`show_results` is a type-dispatched function that is responsible for showing decoded `samples` and their corresponding `outs`. Like in `show_batch`, `x` and `y` are the input and the target in the batch to be shown, and are passed along to dispatch on their types. `ctxs` can be passed but the function is responsible for creating them if necessary. `kwargs` depend on the specific implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# NOTE(review): presumably read by nbdev to add these names to the exported module's `__all__` - confirm\n",
    "_all_ = [\"show_batch\", \"show_results\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Names of the `DataLoader` callbacks that `TfmdDL` wraps in a `Pipeline`\n",
    "_batch_tfms = ('after_item','before_batch','after_batch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@log_args(but_as=DataLoader.__init__)\n",
    "@delegates()\n",
    "class TfmdDL(DataLoader):\n",
    "    \"Transformed `DataLoader`\"\n",
    "    def __init__(self, dataset, bs=64, shuffle=False, num_workers=None, verbose=False, do_setup=True, **kwargs):\n",
    "        # Default number of workers: `defaults.cpus`, capped at 16\n",
    "        if num_workers is None: num_workers = min(16, defaults.cpus)\n",
    "        # Wrap each batch-level callback in a `Pipeline` (built from `None` when the callback is not given)\n",
    "        for nm in _batch_tfms: kwargs[nm] = Pipeline(kwargs.get(nm,None))\n",
    "        super().__init__(dataset, bs=bs, shuffle=shuffle, num_workers=num_workers, **kwargs)\n",
    "        if do_setup:\n",
    "            for nm in _batch_tfms:\n",
    "                pv(f\"Setting up {nm}: {kwargs[nm]}\", verbose)\n",
    "                kwargs[nm].setup(self)\n",
    "\n",
    "    def _one_pass(self):\n",
    "        \"Build a batch from item 0 and record its number of inputs (`_n_inp`) and its types (`_types`)\"\n",
    "        b = self.do_batch([self.do_item(0)])\n",
    "        if self.device is not None: b = to_device(b, self.device)\n",
    "        its = self.after_batch(b)\n",
    "        # A non-list/tuple batch or a single-element one counts as 1 input; otherwise all elements but the last\n",
    "        self._n_inp = 1 if not isinstance(its, (list,tuple)) or len(its)==1 else len(its)-1\n",
    "        self._types = explode_types(its)\n",
    "\n",
    "    def _retain_dl(self,b):\n",
    "        \"Apply `retain_types` to `b` with the recorded `_types`, running `_one_pass` first if they are missing\"\n",
    "        if not getattr(self, '_types', None): self._one_pass()\n",
    "        return retain_types(b, typs=self._types)\n",
    "\n",
    "    @delegates(DataLoader.new)\n",
    "    def new(self, dataset=None, cls=None, **kwargs):\n",
    "        # The copy is created with `do_setup=False`, then inherits `_n_inp` and `_types` from `self`\n",
    "        res = super().new(dataset, cls, do_setup=False, **kwargs)\n",
    "        if not hasattr(self, '_n_inp') or not hasattr(self, '_types'):\n",
    "            try:\n",
    "                self._one_pass()\n",
    "                res._n_inp,res._types = self._n_inp,self._types\n",
    "            # Best effort: report the failure and still return `res`. `Exception` (not a bare `except`)\n",
    "            # so that `KeyboardInterrupt` and `SystemExit` are not swallowed.\n",
    "            except Exception: print(\"Could not do one pass in your dataloader, there is something wrong in it\")\n",
    "        else: res._n_inp,res._types = self._n_inp,self._types\n",
    "        return res\n",
    "\n",
    "    def before_iter(self):\n",
    "        super().before_iter()\n",
    "        # Hand the dataset's `split_idx` (or `None`) to every batch-level callback that is a `Pipeline`\n",
    "        split_idx = getattr(self.dataset, 'split_idx', None)\n",
    "        for nm in _batch_tfms:\n",
    "            f = getattr(self,nm)\n",
    "            if isinstance(f,Pipeline): f.split_idx=split_idx\n",
    "\n",
    "    # Decoding order: `after_batch.decode`, then `to_cpu`, then `before_batch.decode`\n",
    "    def decode(self, b): return self.before_batch.decode(to_cpu(self.after_batch.decode(self._retain_dl(b))))\n",
    "    def decode_batch(self, b, max_n=9, full=True): return self._decode_batch(self.decode(b), max_n, full)\n",
    "\n",
    "    def _decode_batch(self, b, max_n=9, full=True):\n",
    "        \"Turn `b` into at most `max_n` samples and map `after_item.decode` composed with the dataset's `decode` over them\"\n",
    "        f = self.after_item.decode\n",
    "        # `noop` stands in when the dataset has no `decode`\n",
    "        f = compose(f, partial(getattr(self.dataset,'decode',noop), full = full))\n",
    "        return L(batch_to_samples(b, max_n=max_n)).map(f)\n",
    "\n",
    "    def _pre_show_batch(self, b, max_n=9):\n",
    "        \"Decode `b` to be ready for `show_batch`\"\n",
    "        b = self.decode(b)\n",
    "        # A batch that can show itself is returned as is, without targets or samples\n",
    "        if hasattr(b, 'show'): return b,None,None\n",
    "        its = self._decode_batch(b, max_n, full=False)\n",
    "        if not is_listy(b): b,its = [b],L((o,) for o in its)\n",
    "        # Returns (inputs, targets, samples), split at `n_inp`\n",
    "        return detuplify(b[:self.n_inp]),detuplify(b[self.n_inp:]),its\n",
    "\n",
    "    def show_batch(self, b=None, max_n=9, ctxs=None, show=True, unique=False, **kwargs):\n",
    "        # With `unique`, `get_idxs` is temporarily replaced so that it returns `Inf.zeros`\n",
    "        if unique:\n",
    "            old_get_idxs = self.get_idxs\n",
    "            self.get_idxs = lambda: Inf.zeros\n",
    "        try:\n",
    "            if b is None: b = self.one_batch()\n",
    "            # `show=False` returns the decoded (inputs, targets, samples) instead of displaying them\n",
    "            if not show: return self._pre_show_batch(b, max_n=max_n)\n",
    "            show_batch(*self._pre_show_batch(b, max_n=max_n), ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "        finally:\n",
    "            # Always restore `get_idxs`, including on the early return above and on errors\n",
    "            if unique: self.get_idxs = old_get_idxs\n",
    "\n",
    "    def show_results(self, b, out, max_n=9, ctxs=None, show=True, **kwargs):\n",
    "        x,y,its = self.show_batch(b, max_n=max_n, show=False)\n",
    "        # Build a second batch made of the inputs of `b` followed by `out`, and decode it the same way\n",
    "        b_out = type(b)(b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,)))\n",
    "        x1,y1,outs = self.show_batch(b_out, max_n=max_n, show=False)\n",
    "        # Only the non-input part of each decoded output sample is kept\n",
    "        res = (x,x1,None,None) if its is None else (x, y, its, outs.itemgot(slice(self.n_inp,None)))\n",
    "        if not show: return res\n",
    "        show_results(*res, ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "\n",
    "    @property\n",
    "    def n_inp(self):\n",
    "        \"Number of inputs: the dataset's `n_inp` when it has one, else the value recorded by `_one_pass`\"\n",
    "        if hasattr(self.dataset, 'n_inp'): return self.dataset.n_inp\n",
    "        if not hasattr(self, '_n_inp'): self._one_pass()\n",
    "        return self._n_inp\n",
    "\n",
    "    def to(self, device):\n",
    "        self.device = device\n",
    "        # Move every attribute that an `after_batch` transform names in `parameters` to `device`\n",
    "        for tfm in self.after_batch.fs:\n",
    "            for a in L(getattr(tfm, 'parameters', None)): setattr(tfm, a, getattr(tfm, a).to(device))\n",
    "        return self"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `TfmdDL` is a `DataLoader` that creates `Pipeline` from a list of `Transform`s for the callbacks `after_item`, `before_batch` and `after_batch`. As a result, it can decode or show a processed `batch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Docstrings for the public methods of `TfmdDL`\n",
    "add_docs(TfmdDL,\n",
    "         decode=\"Decode `b` using `tfms`\",\n",
    "         decode_batch=\"Decode `b` entirely\",\n",
    "         new=\"Create a new version of self with a few changed attributes\",\n",
    "         show_batch=\"Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a `DataLoader`)\",\n",
    "         show_results=\"Show each item of `b` and `out`\",\n",
    "         before_iter=\"override\",\n",
    "         to=\"Put self and its transforms state on `device`\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An `int` subclass that also inherits from `ShowTitle`\n",
    "class _Category(int, ShowTitle): pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test retain type\n",
    "class NegTfm(Transform):\n",
    "    # Negates its input on encode and negates again on decode\n",
    "    def encodes(self, x): return torch.neg(x)\n",
    "    def decodes(self, x): return torch.neg(x)\n",
    "    \n",
    "# A batch built from `TensorImage` items is still a `TensorImage` after `after_batch`\n",
    "tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)\n",
    "b = tdl.one_batch()\n",
    "test_eq(type(b[0]), TensorImage)\n",
    "# Decoding a batch given as a plain tensor yields `TensorImage` samples\n",
    "b = (tensor([1.,1.,1.,1.]),)\n",
    "test_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class A(Transform): \n",
    "    # Identity on encode; wraps the value in a `TitledInt` on decode\n",
    "    def encodes(self, x): return x \n",
    "    def decodes(self, x): return TitledInt(x) \n",
    "\n",
    "# `f` duplicates its input into a `fastuple`\n",
    "@Transform\n",
    "def f(x)->None: return fastuple((x,x))\n",
    "\n",
    "start = torch.arange(50)\n",
    "test_eq_type(f(2), fastuple((2,2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = A()\n",
    "# `after_item` returns a pair, so a batch unpacks into `x` and `y`; `y` comes from `f` and is a `fastuple`\n",
    "tdl = TfmdDL(start, after_item=lambda x: (a(x), f(x)), bs=4)\n",
    "x,y = tdl.one_batch()\n",
    "test_eq(type(y), fastuple)\n",
    "\n",
    "# The `fastuple` type is still there in the decoded samples\n",
    "s = tdl.decode_batch((x,y))\n",
    "test_eq(type(s[0][1]), fastuple)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tdl = TfmdDL(torch.arange(0,50), after_item=A(), after_batch=NegTfm(), bs=4)\n",
    "test_eq(tdl.dataset[0], start[0])\n",
    "# 50 items in batches of 4: the length of the loader is the rounded-up number of batches\n",
    "test_eq(len(tdl), (50-1)//4+1)\n",
    "test_eq(tdl.bs, 4)\n",
    "test_stdout(tdl.show_batch, '0\\n1\\n2\\n3')\n",
    "# With `unique=True` the same item is shown for every element of the batch\n",
    "test_stdout(partial(tdl.show_batch, unique=True), '0\\n0\\n0\\n0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class B(Transform):\n",
    "    # `parameters` names the attribute(s) that `TfmdDL.to` moves to the target device\n",
    "    parameters = 'a'\n",
    "    def __init__(self): self.a = torch.tensor(0.)\n",
    "    # Identity transform: `x` has to be returned, otherwise `encodes` yields `None`\n",
    "    def encodes(self, x): return x\n",
    "    \n",
    "# `a` starts on the CPU and follows the loader when it is moved with `to`\n",
    "tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=B(), bs=4)\n",
    "test_eq(tdl.after_batch.fs[0].a.device, torch.device('cpu'))\n",
    "tdl.to(default_device())\n",
    "test_eq(tdl.after_batch.fs[0].a.device, default_device())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataLoader.one_batch\" class=\"doc_header\"><code>DataLoader.one_batch</code><a href=\"https://github.com/fastai/fastai/tree/master/fastai/data/load.py#L135\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>DataLoader.one_batch</code>()\n",
       "\n",
       "Return one batch from [`DataLoader`](/data.load.html#DataLoader)."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TfmdDL.one_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfm = NegTfm()\n",
    "tdl = TfmdDL(start, after_batch=tfm, bs=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = tdl.one_batch()\n",
    "test_eq(tensor([0,-1,-2,-3]), b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TfmdDL.decode\" class=\"doc_header\"><code>TfmdDL.decode</code><a href=\"__main__.py#L44\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TfmdDL.decode</code>(**`b`**)\n",
       "\n",
       "Decode `b` using `tfms`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TfmdDL.decode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(tdl.decode(b), tensor(0,1,2,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TfmdDL.decode_batch\" class=\"doc_header\"><code>TfmdDL.decode_batch</code><a href=\"__main__.py#L45\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TfmdDL.decode_batch</code>(**`b`**, **`max_n`**=*`9`*, **`full`**=*`True`*)\n",
       "\n",
       "Decode `b` entirely"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TfmdDL.decode_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(tdl.decode_batch(b), [0,1,2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TfmdDL.show_batch\" class=\"doc_header\"><code>TfmdDL.show_batch</code><a href=\"__main__.py#L60\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TfmdDL.show_batch</code>(**`b`**=*`None`*, **`max_n`**=*`9`*, **`ctxs`**=*`None`*, **`show`**=*`True`*, **`unique`**=*`False`*, **\\*\\*`kwargs`**)\n",
       "\n",
       "Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a [`DataLoader`](/data.load.html#DataLoader))"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TfmdDL.show_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TfmdDL.to\" class=\"doc_header\"><code>TfmdDL.to</code><a href=\"__main__.py#L83\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TfmdDL.to</code>(**`device`**)\n",
       "\n",
       "Put self and its transforms state on `device`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TfmdDL.to)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DataLoaders -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "@docs\n",
    "class DataLoaders(GetAttr):\n",
    "    \"Basic wrapper around several `DataLoader`s.\"\n",
    "    _default='train'\n",
    "    def __init__(self, *loaders, path='.', device=None):\n",
    "        self.loaders,self.path = list(loaders),Path(path)\n",
    "        # Assigning `device` calls `to` on every loader, so only do it when a device was requested or\n",
    "        # the first loader has a `to` method. `loaders` is tested for emptiness first so that\n",
    "        # `DataLoaders()` does not fail with an `IndexError`.\n",
    "        if device is not None or (loaders and hasattr(loaders[0],'to')): self.device = device\n",
    "\n",
    "    def __getitem__(self, i): return self.loaders[i]\n",
    "    def new_empty(self):\n",
    "        # Rebuild every loader on the result of its dataset's `new_empty`\n",
    "        loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]\n",
    "        return type(self)(*loaders, path=self.path, device=self.device)\n",
    "\n",
    "    # Setter used by the `train`/`valid` properties: replaces the loader at position `i`\n",
    "    def _set(i, self, v): self.loaders[i] = v\n",
    "    train   ,valid    = add_props(lambda i,x: x[i], _set)\n",
    "    train_ds,valid_ds = add_props(lambda i,x: x[i].dataset)\n",
    "\n",
    "    @property\n",
    "    def device(self): return self._device\n",
    "\n",
    "    @device.setter\n",
    "    def device(self, d):\n",
    "        # Setting the device moves every wrapped loader before recording it\n",
    "        for dl in self.loaders: dl.to(d)\n",
    "        self._device = d\n",
    "\n",
    "    def to(self, device):\n",
    "        self.device = device\n",
    "        return self\n",
    "\n",
    "    def cuda(self): return self.to(device=default_device())\n",
    "    def cpu(self):  return self.to(device=torch.device('cpu'))\n",
    "\n",
    "    @classmethod\n",
    "    def from_dsets(cls, *ds, path='.',  bs=64, device=None, dl_type=TfmdDL, **kwargs):\n",
    "        \"Create one `dl_type` per dataset in `ds`; each value in `kwargs` is matched to `ds` with `tuplify`\"\n",
    "        # Defaults: only the first loader gets `shuffle=True` and `drop_last=True`\n",
    "        default = (True,) + (False,) * (len(ds)-1)\n",
    "        defaults = {'shuffle': default, 'drop_last': default}\n",
    "        for nm in _batch_tfms:\n",
    "            if nm in kwargs: kwargs[nm] = Pipeline(kwargs[nm])\n",
    "        kwargs = merge(defaults, {k: tuplify(v, match=ds) for k,v in kwargs.items()})\n",
    "        # One kwargs dict per dataset, taking the i-th value of every option\n",
    "        kwargs = [{k: v[i] for k,v in kwargs.items()} for i in range_of(ds)]\n",
    "        return cls(*[dl_type(d, bs=bs, **k) for d,k in zip(ds, kwargs)], path=path, device=device)\n",
    "\n",
    "    @classmethod\n",
    "    def from_dblock(cls, dblock, source, path='.',  bs=64, val_bs=None, shuffle_train=True, device=None, **kwargs):\n",
    "        return dblock.dataloaders(source, path=path, bs=bs, val_bs=val_bs, shuffle_train=shuffle_train, device=device, **kwargs)\n",
    "\n",
    "    _docs=dict(__getitem__=\"Retrieve `DataLoader` at `i` (`0` is training, `1` is validation)\",\n",
    "               train=\"Training `DataLoader`\",\n",
    "               valid=\"Validation `DataLoader`\",\n",
    "               train_ds=\"Training `Dataset`\",\n",
    "               valid_ds=\"Validation `Dataset`\",\n",
    "               to=\"Use `device`\",\n",
    "               cuda=\"Use the gpu if available\",\n",
    "               cpu=\"Use the cpu\",\n",
    "               new_empty=\"Create a new empty version of `self` with the same transforms\",\n",
    "               from_dblock=\"Create a dataloaders from a given `dblock`\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `tdl` serves as both the training and the validation loader here\n",
    "dls = DataLoaders(tdl,tdl)\n",
    "x = dls.train.one_batch()\n",
    "x2 = first(tdl)\n",
    "test_eq(x,x2)\n",
    "# `one_batch` called on `dls` itself gives the same batch as on `dls.train`\n",
    "x2 = dls.one_batch()\n",
    "test_eq(x,x2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#test assignment works\n",
    "dls.train = dls.train.new(bs=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataLoaders.__getitem__\" class=\"doc_header\"><code>DataLoaders.__getitem__</code><a href=\"__main__.py#L10\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>DataLoaders.__getitem__</code>(**`i`**)\n",
       "\n",
       "Retrieve [`DataLoader`](/data.load.html#DataLoader) at `i` (`0` is training, `1` is validation)"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(DataLoaders.__getitem__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x2 = dls[0].one_batch()\n",
    "test_eq(x,x2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataLoaders.train\" class=\"doc_header\"><code>DataLoaders.train</code><a href=\"__main__.py#L16\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "Training [`DataLoader`](/data.load.html#DataLoader)"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(DataLoaders.train, name=\"DataLoaders.train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataLoaders.valid\" class=\"doc_header\"><code>DataLoaders.valid</code><a href=\"__main__.py#L16\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "Validation [`DataLoader`](/data.load.html#DataLoader)"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(DataLoaders.valid, name=\"DataLoaders.valid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataLoaders.train_ds\" class=\"doc_header\"><code>DataLoaders.train_ds</code><a href=\"__main__.py#L17\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "Training `Dataset`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(DataLoaders.train_ds, name=\"DataLoaders.train_ds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataLoaders.valid_ds\" class=\"doc_header\"><code>DataLoaders.valid_ds</code><a href=\"__main__.py#L17\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "Validation `Dataset`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(DataLoaders.valid_ds, name=\"DataLoaders.valid_ds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TfmdLists -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class FilteredBase:\n",
    "    \"Base class for lists with subsets\"\n",
    "    _dl_type,_dbunch_type = TfmdDL,DataLoaders\n",
    "    def __init__(self, *args, dl_type=None, **kwargs):\n",
    "        if dl_type is not None: self._dl_type = dl_type\n",
    "        # Re-wrap `dataloaders` with `delegates` on the `__init__` of the loader type in use\n",
    "        self.dataloaders = delegates(self._dl_type.__init__)(self.dataloaders)\n",
    "        super().__init__(*args, **kwargs)\n",
    "\n",
    "    @property\n",
    "    def n_subsets(self): return len(self.splits)\n",
    "    def _new(self, items, **kwargs): return super()._new(items, splits=self.splits, **kwargs)\n",
    "    # To be overridden by subclasses. `NotImplementedError` is the exception (`NotImplemented` is a\n",
    "    # constant and cannot be raised); `i` is accepted because callers in this class pass a subset index.\n",
    "    def subset(self, i=None): raise NotImplementedError\n",
    "\n",
    "    def dataloaders(self, bs=64, val_bs=None, shuffle_train=True, n=None, path='.', dl_type=None, dl_kwargs=None,\n",
    "                    device=None, **kwargs):\n",
    "        \"Build one loader per subset (subset 0 with `bs`/`shuffle_train`, the others with `val_bs` or `bs` and no shuffling) and wrap them in `_dbunch_type`\"\n",
    "        if device is None: device=default_device()\n",
    "        if dl_kwargs is None: dl_kwargs = [{}] * self.n_subsets\n",
    "        if dl_type is None: dl_type = self._dl_type\n",
    "        # `drop_last` follows `shuffle_train` unless it is passed explicitly\n",
    "        drop_last = kwargs.pop('drop_last', shuffle_train)\n",
    "        dl = dl_type(self.subset(0), bs=bs, shuffle=shuffle_train, drop_last=drop_last, n=n, device=device,\n",
    "                     **merge(kwargs, dl_kwargs[0]))\n",
    "        # The other loaders are derived from the first one with `new`\n",
    "        dls = [dl] + [dl.new(self.subset(i), bs=(bs if val_bs is None else val_bs), shuffle=False, drop_last=False,\n",
    "                             n=None, **dl_kwargs[i]) for i in range(1, self.n_subsets)]\n",
    "        return self._dbunch_type(*dls, path=path, device=device)\n",
    "\n",
    "# `train` and `valid` are properties returning `subset(0)` and `subset(1)`\n",
    "FilteredBase.train,FilteredBase.valid = add_props(lambda i,x: x.subset(i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TfmdLists(FilteredBase, L, GetAttr):\n",
    "    \"A `Pipeline` of `tfms` applied to a collection of `items`\"\n",
    "    _default='tfms'\n",
    "    def __init__(self, items, tfms, use_list=None, do_setup=True, split_idx=None, train_setup=True,\n",
    "                 splits=None, types=None, verbose=False, dl_type=None):\n",
    "        super().__init__(items, use_list=use_list)\n",
    "        if dl_type is not None: self._dl_type = dl_type\n",
    "        # Without `splits`: a first split selecting `slice(None)` and an empty second one\n",
    "        self.splits = L([slice(None),[]] if splits is None else splits).map(mask2idxs)\n",
    "        # Reuse the pipeline of another `TfmdLists`; when `tfms` is already a `Pipeline`, setup is skipped\n",
    "        if isinstance(tfms,TfmdLists): tfms = tfms.tfms\n",
    "        if isinstance(tfms,Pipeline): do_setup=False\n",
    "        self.tfms = Pipeline(tfms, split_idx=split_idx)\n",
    "        store_attr('types,split_idx')\n",
    "        if do_setup:\n",
    "            pv(f\"Setting up {self.tfms}\", verbose)\n",
    "            self.setup(train_setup=train_setup)\n",
    "\n",
    "    def _new(self, items, split_idx=None, **kwargs):\n",
    "        # The copy shares `tfms` and `types` with `self` and is not set up again\n",
    "        split_idx = ifnone(split_idx,self.split_idx)\n",
    "        return super()._new(items, tfms=self.tfms, do_setup=False, types=self.types, split_idx=split_idx, **kwargs)\n",
    "    def subset(self, i): return self._new(self._get(self.splits[i]), split_idx=i)\n",
    "    def _after_item(self, o): return self.tfms(o)\n",
    "    def __repr__(self): return f\"{self.__class__.__name__}: {self.items}\\ntfms - {self.tfms.fs}\"\n",
    "    def __iter__(self): return (self[i] for i in range(len(self)))\n",
    "    def show(self, o, **kwargs): return self.tfms.show(o, **kwargs)\n",
    "    def decode(self, o, **kwargs): return self.tfms.decode(o, **kwargs)\n",
    "    def __call__(self, o, **kwargs): return self.tfms.__call__(o, **kwargs)\n",
    "    def overlapping_splits(self): return L(Counter(self.splits.concat()).values()).filter(gt(1))\n",
    "    def new_empty(self): return self._new([])\n",
    "\n",
    "    def setup(self, train_setup=True):\n",
    "        self.tfms.setup(self, train_setup)\n",
    "        if len(self) != 0:\n",
    "            # Run the first item of the first split through the transforms, recording for each one its\n",
    "            # `input_types` (or the type of the value it receives), then the type of the final output\n",
    "            x = super().__getitem__(0) if self.splits is None else super().__getitem__(self.splits[0])[0]\n",
    "            self.types = []\n",
    "            for f in self.tfms.fs:\n",
    "                self.types.append(getattr(f, 'input_types', type(x)))\n",
    "                x = f(x)\n",
    "            self.types.append(type(x))\n",
    "        # `self.types` is still `None` for an empty collection created without `types`: treat it as no types\n",
    "        types = L(t if is_listy(t) else [t] for t in ifnone(self.types, [])).concat().unique()\n",
    "        self.pretty_types = '\\n'.join([f'  - {t}' for t in types])\n",
    "\n",
    "    def infer_idx(self, x):\n",
    "        # TODO: check if we really need this, or can simplify\n",
    "        # `self.types` may be `None` (nothing recorded); an empty list then makes the assertion below fail cleanly\n",
    "        typs = ifnone(self.types, [])\n",
    "        # Position of the first recorded type that `x` is an instance of\n",
    "        idx = 0\n",
    "        for t in typs:\n",
    "            if isinstance(x, t): break\n",
    "            idx += 1\n",
    "        types = L(t if is_listy(t) else [t] for t in typs).concat().unique()\n",
    "        pretty_types = '\\n'.join([f'  - {t}' for t in types])\n",
    "        assert idx < len(typs), f\"Expected an input of type in \\n{pretty_types}\\n but got {type(x)}\"\n",
    "        return idx\n",
    "\n",
    "    def infer(self, x):\n",
    "        # Apply only the transforms from the position found by `infer_idx` onwards\n",
    "        return compose_tfms(x, tfms=self.tfms.fs[self.infer_idx(x):], split_idx=self.split_idx)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        res = super().__getitem__(idx)\n",
    "        if self._after_item is None: return res\n",
    "        # A single index gives one transformed item; anything else gives a collection mapped item by item\n",
    "        return self._after_item(res) if is_indexer(idx) else res.map(self._after_item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Docstrings for the public methods of `TfmdLists`\n",
    "# NOTE(review): `overlapping_splits` appears to return the occurrence counts (above 1) of the indices\n",
    "# found in several splits rather than the splits themselves - consider rewording its docstring\n",
    "add_docs(TfmdLists,\n",
    "         setup=\"Transform setup with self\",\n",
    "         decode=\"From `Pipeline`\",\n",
    "         show=\"From `Pipeline`\",\n",
    "         overlapping_splits=\"All splits that are in more than one split\",\n",
    "         subset=\"New `TfmdLists` with same tfms that only includes items in `i`th split\",\n",
    "         infer_idx=\"Finds the index where `self.tfms` can be applied to `x`, depending on the type of `x`\",\n",
    "         infer=\"Apply `self.tfms` to `x` starting at the right tfm depending on the type of `x`\",\n",
    "         new_empty=\"A new version of `self` but with no items\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#exports\n",
    "def decode_at(o, idx):\n",
    "    \"Decoded item at `idx`\"\n",
    "    return o.decode(o[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#exports\n",
    "def show_at(o, idx, **kwargs):\n",
    "    \"Show item at `idx`\",\n",
    "    return o.show(o[idx], **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `TfmdLists` combines a collection of objects with a `Pipeline`. `tfms` can either be a `Pipeline` or a list of transforms, in which case, it will wrap them in a `Pipeline`. `use_list` is passed along to `L` with the `items`, and `split_idx` is passed to each transform of the `Pipeline`. `do_setup` indicates if the `Pipeline.setup` method should be called during initialization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _IntFloatTfm(Transform):\n",
    "    # Encodes to `TitledInt` and decodes to `TitledFloat`\n",
    "    def encodes(self, o):  return TitledInt(o)\n",
    "    def decodes(self, o):  return TitledFloat(o)\n",
    "int2f_tfm=_IntFloatTfm()\n",
    "\n",
    "# Negation is used both to encode and to decode\n",
    "def _neg(o): return -o\n",
    "neg_tfm = Transform(_neg, _neg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TfmdLists: [1.0, 2.0, 3.0]\n",
       "tfms - (#2) [_neg:\n",
       "encodes: (object,object) -> _negdecodes: (object,object) -> _neg,_IntFloatTfm:\n",
       "encodes: (object,object) -> encodes\n",
       "decodes: (object,object) -> decodes\n",
       "]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "items = L([1.,2.,3.]); tfms = [neg_tfm, int2f_tfm]\n",
    "tl = TfmdLists(items, tfms=tfms)\n",
    "test_eq_type(tl[0], TitledInt(-1))\n",
    "test_eq_type(tl[1], TitledInt(-2))\n",
    "test_eq_type(tl.decode(tl[2]), TitledFloat(3.))\n",
    "test_stdout(lambda: show_at(tl, 2), '-3')\n",
    "test_eq(tl.types, [float, float, TitledInt])\n",
    "tl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add splits to TfmdLists\n",
    "splits = [[0,2],[1]]\n",
    "tl = TfmdLists(items, tfms=tfms, splits=splits)\n",
    "test_eq(tl.n_subsets, 2)\n",
    "test_eq(tl.train, tl.subset(0))\n",
    "test_eq(tl.valid, tl.subset(1))\n",
    "test_eq(tl.train.items, items[splits[0]])\n",
    "test_eq(tl.valid.items, items[splits[1]])\n",
    "test_eq(tl.train.tfms.split_idx, 0)\n",
    "test_eq(tl.valid.tfms.split_idx, 1)\n",
    "test_eq(tl.train.new_empty().split_idx, 0)\n",
    "test_eq(tl.valid.new_empty().split_idx, 1)\n",
    "test_eq_type(tl.splits, L(splits))\n",
    "assert not tl.overlapping_splits()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(dict(a=[1,2,3],b=[2,3,4]))\n",
    "tl = TfmdLists(df, lambda o: o.a+1, splits=[[0],[1,2]])\n",
    "test_eq(tl[1,2], [3,4])\n",
    "tr = tl.subset(0)\n",
    "test_eq(tr[:], [2])\n",
    "val = tl.subset(1)\n",
    "test_eq(val[:], [3,4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TfmdLists: [1.0, 2.0, 3.0]\n",
      "tfms - (#0) []\n"
     ]
    }
   ],
   "source": [
    "class _B(Transform):\n",
    "    def __init__(self): self.m = 0\n",
    "    def encodes(self, o): return o+self.m\n",
    "    def decodes(self, o): return o-self.m\n",
    "    def setups(self, items): \n",
    "        print(items)\n",
    "        self.m = tensor(items).float().mean().item()\n",
    "\n",
    "# test for setup, which updates `self.m`\n",
    "tl = TfmdLists(items, _B())\n",
    "test_eq(tl.m, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's how we can use `TfmdLists.setup` to implement a simple category list, getting labels from a mock file list:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _Cat(Transform):\n",
    "    order = 1\n",
    "    def encodes(self, o):    return int(self.o2i[o])\n",
    "    def decodes(self, o):    return TitledStr(self.vocab[o])\n",
    "    def setups(self, items): self.vocab,self.o2i = uniqueify(L(items), sort=True, bidir=True)\n",
    "tcat = _Cat()\n",
    "\n",
    "def _lbl(o): return TitledStr(o.split('_')[0])\n",
    "\n",
    "# Check that tfms are sorted by `order` & `_lbl` is called first\n",
    "fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']\n",
    "tl = TfmdLists(fns, [tcat,_lbl])\n",
    "exp_voc = ['cat','dog']\n",
    "test_eq(tcat.vocab, exp_voc)\n",
    "test_eq(tl.tfms.vocab, exp_voc)\n",
    "test_eq(tl.vocab, exp_voc)\n",
    "test_eq(tl, (1,0,0,0,1))\n",
    "test_eq([tl.decode(o) for o in tl], ('dog','cat','cat','cat','dog'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Check only the training set is taken into account for setup\n",
    "tl = TfmdLists(fns, [tcat,_lbl], splits=[[0,4], [1,2,3]])\n",
    "test_eq(tcat.vocab, ['dog'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfm = NegTfm(split_idx=1)\n",
    "tds = TfmdLists(start, A())\n",
    "tdl = TfmdDL(tds, after_batch=tfm, bs=4)\n",
    "x = tdl.one_batch()\n",
    "test_eq(x, torch.arange(4))\n",
    "tds.split_idx = 1\n",
    "x = tdl.one_batch()\n",
    "test_eq(x, -torch.arange(4))\n",
    "tds.split_idx = 0\n",
    "x = tdl.one_batch()\n",
    "test_eq(x, torch.arange(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tds = TfmdLists(start, A())\n",
    "tdl = TfmdDL(tds, after_batch=NegTfm(), bs=4)\n",
    "test_eq(tdl.dataset[0], start[0])\n",
    "test_eq(len(tdl), (len(tds)-1)//4+1)\n",
    "test_eq(tdl.bs, 4)\n",
    "test_stdout(tdl.show_batch, '0\\n1\\n2\\n3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TfmdLists.subset\" class=\"doc_header\"><code>TfmdLists.subset</code><a href=\"__main__.py#L21\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TfmdLists.subset</code>(**`i`**)\n",
       "\n",
       "New [`TfmdLists`](/data.core.html#TfmdLists) with same tfms that only includes items in `i`th split"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TfmdLists.subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TfmdLists.infer_idx\" class=\"doc_header\"><code>TfmdLists.infer_idx</code><a href=\"__main__.py#L43\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TfmdLists.infer_idx</code>(**`x`**)\n",
       "\n",
       "Finds the index where `self.tfms` can be applied to `x`, depending on the type of `x`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TfmdLists.infer_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TfmdLists.infer\" class=\"doc_header\"><code>TfmdLists.infer</code><a href=\"__main__.py#L54\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TfmdLists.infer</code>(**`x`**)\n",
       "\n",
       "Apply `self.tfms` to `x` starting at the right tfm depending on the type of `x`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TfmdLists.infer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mult(x): return x*2\n",
    "mult.order = 2\n",
    "\n",
    "fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']\n",
    "tl = TfmdLists(fns, [_lbl,_Cat(),mult])\n",
    "\n",
    "test_eq(tl.infer_idx('dog_45.jpg'), 0)\n",
    "test_eq(tl.infer('dog_45.jpg'), 2)\n",
    "\n",
    "test_eq(tl.infer_idx(4), 2)\n",
    "test_eq(tl.infer(4), 8)\n",
    "\n",
    "test_fail(lambda: tl.infer_idx(2.0))\n",
    "test_fail(lambda: tl.infer(2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test input_types works on a Transform\n",
    "cat = _Cat()\n",
    "cat.input_types = (str, float)\n",
    "tl = TfmdLists(fns, [_lbl,cat,mult])\n",
    "test_eq(tl.infer_idx(2.0), 1)\n",
    "\n",
    "#Test type annotations work on a function\n",
    "def mult(x:(int,float)): return x*2\n",
    "mult.order = 2\n",
    "tl = TfmdLists(fns, [_lbl,_Cat(),mult])\n",
    "test_eq(tl.infer_idx(2.0), 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Datasets -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@docs\n",
    "@delegates(TfmdLists)\n",
    "class Datasets(FilteredBase):\n",
    "    \"A dataset that creates a tuple from each `tfms`, passed through `item_tfms`\"\n",
    "    def __init__(self, items=None, tfms=None, tls=None, n_inp=None, dl_type=None, **kwargs):\n",
    "        super().__init__(dl_type=dl_type)\n",
    "        self.tls = L(tls if tls else [TfmdLists(items, t, **kwargs) for t in L(ifnone(tfms,[None]))])\n",
    "        self.n_inp = ifnone(n_inp, max(1, len(self.tls)-1))\n",
    "\n",
    "    def __getitem__(self, it):\n",
    "        res = tuple([tl[it] for tl in self.tls])\n",
    "        return res if is_indexer(it) else list(zip(*res))\n",
    "\n",
    "    def __getattr__(self,k): return gather_attrs(self, k, 'tls')\n",
    "    def __dir__(self): return super().__dir__() + gather_attr_names(self, 'tls')\n",
    "    def __len__(self): return len(self.tls[0])\n",
    "    def __iter__(self): return (self[i] for i in range(len(self)))\n",
    "    def __repr__(self): return coll_repr(self)\n",
    "    def decode(self, o, full=True): return tuple(tl.decode(o_, full=full) for o_,tl in zip(o,tuplify(self.tls, match=o)))\n",
    "    def subset(self, i): return type(self)(tls=L(tl.subset(i) for tl in self.tls), n_inp=self.n_inp)\n",
    "    def _new(self, items, *args, **kwargs): return super()._new(items, tfms=self.tfms, do_setup=False, **kwargs)\n",
    "    def overlapping_splits(self): return self.tls[0].overlapping_splits()\n",
    "    def new_empty(self): return type(self)(tls=[tl.new_empty() for tl in self.tls], n_inp=self.n_inp)\n",
    "    @property\n",
    "    def splits(self): return self.tls[0].splits\n",
    "    @property\n",
    "    def split_idx(self): return self.tls[0].tfms.split_idx\n",
    "    @property\n",
    "    def items(self): return self.tls[0].items\n",
    "    @items.setter\n",
    "    def items(self, v):\n",
    "        for tl in self.tls: tl.items = v\n",
    "\n",
    "    def show(self, o, ctx=None, **kwargs):\n",
    "        for o_,tl in zip(o,self.tls): ctx = tl.show(o_, ctx=ctx, **kwargs)\n",
    "        return ctx\n",
    "\n",
    "    @contextmanager\n",
    "    def set_split_idx(self, i):\n",
    "        old_split_idx = self.split_idx\n",
    "        for tl in self.tls: tl.tfms.split_idx = i\n",
    "        try: yield self\n",
    "        finally:\n",
    "            for tl in self.tls: tl.tfms.split_idx = old_split_idx\n",
    "\n",
    "    _docs=dict(\n",
    "        decode=\"Compose `decode` of all `tuple_tfms` then all `tfms` on `i`\",\n",
    "        show=\"Show item `o` in `ctx`\",\n",
    "        dataloaders=\"Get a `DataLoaders`\",\n",
    "        overlapping_splits=\"All splits that are in more than one split\",\n",
    "        subset=\"New `Datasets` that only includes subset `i`\",\n",
    "        new_empty=\"Create a new empty version of the `self`, keeping only the transforms\",\n",
    "        set_split_idx=\"Contextmanager to use the same `Datasets` with another `split_idx`\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `Datasets` creates a tuple from `items` (typically input,target) by applying to them each list of `Transform` (or `Pipeline`) in `tfms`. Note that if `tfms` contains only one list of `tfms`, the items given by `Datasets` will be tuples of one element. \n",
    "\n",
    "`n_inp` is the number of elements in the tuples that should be considered part of the input and will default to 1 if `tfms` consists of one set of transforms, `len(tfms)-1` otherwise. In most cases, the number of elements in the tuples spit out by `Datasets` will be 2 (for input,target) but it can happen that there is 3 (Siamese networks or tabular data) in which case we need to be able to determine when the inputs end and the targets begin."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 2)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "items = [1,2,3,4]\n",
    "dsets = Datasets(items, [[neg_tfm,int2f_tfm], [add(1)]])\n",
    "t = dsets[0]\n",
    "test_eq(t, (-1,2))\n",
    "test_eq(dsets[0,1,2], [(-1,2),(-2,3),(-3,4)])\n",
    "test_eq(dsets.n_inp, 1)\n",
    "dsets.decode(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Norm(Transform):\n",
    "    def encodes(self, o): return (o-self.m)/self.s\n",
    "    def decodes(self, o): return (o*self.s)+self.m\n",
    "    def setups(self, items):\n",
    "        its = tensor(items).float()\n",
    "        self.m,self.s = its.mean(),its.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "items = [1,2,3,4]\n",
    "nrm = Norm()\n",
    "dsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm,nrm]])\n",
    "\n",
    "x,y = zip(*dsets)\n",
    "test_close(tensor(y).mean(), 0)\n",
    "test_close(tensor(y).std(), 1)\n",
    "test_eq(x, (-1,-2,-3,-4,))\n",
    "test_eq(nrm.m, -2.5)\n",
    "test_stdout(lambda:show_at(dsets, 1), '-2')\n",
    "\n",
    "test_eq(dsets.m, nrm.m)\n",
    "test_eq(dsets.norm.m, nrm.m)\n",
    "test_eq(dsets.train.norm.m, nrm.m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Check filtering is properly applied\n",
    "class B(Transform):\n",
    "    def encodes(self, x)->None:  return int(x+1)\n",
    "    def decodes(self, x):        return TitledInt(x-1)\n",
    "add1 = B(split_idx=1)\n",
    "\n",
    "dsets = Datasets(items, [neg_tfm, [neg_tfm,int2f_tfm,add1]], splits=[[3],[0,1,2]])\n",
    "test_eq(dsets[1], [-2,-2])\n",
    "test_eq(dsets.valid[1], [-2,-1])\n",
    "test_eq(dsets.valid[[1,1]], [[-2,-1], [-2,-1]])\n",
    "test_eq(dsets.train[0], [-4,-4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','kid_1.jpg']\n",
    "tcat = _Cat()\n",
    "dsets = Datasets(test_fns, [[tcat,_lbl]], splits=[[0,1,2], [3,4]])\n",
    "test_eq(tcat.vocab, ['cat','dog'])\n",
    "test_eq(dsets.train, [(1,),(0,),(0,)])\n",
    "test_eq(dsets.valid[0], (0,))\n",
    "test_stdout(lambda: show_at(dsets.train, 0), \"dog\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inp = [0,1,2,3,4]\n",
    "dsets = Datasets(inp, tfms=[None])\n",
    "\n",
    "test_eq(*dsets[2], 2)          # Retrieve one item (subset 0 is the default)\n",
    "test_eq(dsets[1,2], [(1,),(2,)])    # Retrieve two items by index\n",
    "mask = [True,False,False,True,False]\n",
    "test_eq(dsets[mask], [(0,),(3,)])   # Retrieve two items by mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inp = pd.DataFrame(dict(a=[5,1,2,3,4]))\n",
    "dsets = Datasets(inp, tfms=attrgetter('a')).subset(0)\n",
    "test_eq(*dsets[2], 2)          # Retrieve one item (subset 0 is the default)\n",
    "test_eq(dsets[1,2], [(1,),(2,)])    # Retrieve two items by index\n",
    "mask = [True,False,False,True,False]\n",
    "test_eq(dsets[mask], [(5,),(3,)])   # Retrieve two items by mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#test n_inp\n",
    "inp = [0,1,2,3,4]\n",
    "dsets = Datasets(inp, tfms=[None])\n",
    "test_eq(dsets.n_inp, 1)\n",
    "dsets = Datasets(inp, tfms=[[None],[None],[None]])\n",
    "test_eq(dsets.n_inp, 2)\n",
    "dsets = Datasets(inp, tfms=[[None],[None],[None]], n_inp=1)\n",
    "test_eq(dsets.n_inp, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#5) [(0,),(1,),(2,),(3,),(4,)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# splits can be indices\n",
    "dsets = Datasets(range(5), tfms=[None], splits=[tensor([0,2]), [1,3,4]])\n",
    "\n",
    "test_eq(dsets.subset(0), [(0,),(2,)])\n",
    "test_eq(dsets.train, [(0,),(2,)])       # Subset 0 is aliased to `train`\n",
    "test_eq(dsets.subset(1), [(1,),(3,),(4,)])\n",
    "test_eq(dsets.valid, [(1,),(3,),(4,)])     # Subset 1 is aliased to `valid`\n",
    "test_eq(*dsets.valid[2], 4)\n",
    "#assert '[(1,),(3,),(4,)]' in str(dsets) and '[(0,),(2,)]' in str(dsets)\n",
    "dsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# splits can be boolean masks (they don't have to cover all items, but must be disjoint)\n",
    "splits = [[False,True,True,False,True], [True,False,False,False,False]]\n",
    "dsets = Datasets(range(5), tfms=[None], splits=splits)\n",
    "\n",
    "test_eq(dsets.train, [(1,),(2,),(4,)])\n",
    "test_eq(dsets.valid, [(0,)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply transforms to all items\n",
    "tfm = [[lambda x: x*2,lambda x: x+1]]\n",
    "splits = [[1,2],[0,3,4]]\n",
    "dsets = Datasets(range(5), tfm, splits=splits)\n",
    "test_eq(dsets.train,[(3,),(5,)])\n",
    "test_eq(dsets.valid,[(1,),(7,),(9,)])\n",
    "test_eq(dsets.train[False,True], [(5,)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only transform subset 1\n",
    "class _Tfm(Transform):\n",
    "    split_idx=1\n",
    "    def encodes(self, x): return x*2\n",
    "    def decodes(self, x): return TitledStr(x//2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#5) [(0,),(1,),(2,),(3,),(4,)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dsets = Datasets(range(5), [_Tfm()], splits=[[1,2],[0,3,4]])\n",
    "test_eq(dsets.train,[(1,),(2,)])\n",
    "test_eq(dsets.valid,[(0,),(6,),(8,)])\n",
    "test_eq(dsets.train[False,True], [(2,)])\n",
    "dsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#A context manager to change the split_idx and apply the validation transform on the training set\n",
    "ds = dsets.train\n",
    "with ds.set_split_idx(1):\n",
    "    test_eq(ds,[(2,),(4,)])\n",
    "test_eq(dsets.train,[(1,),(2,)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test Datasets pickles\n",
    "dsrc1 = pickle.loads(pickle.dumps(dsets))\n",
    "test_eq(dsets.train, dsrc1.train)\n",
    "test_eq(dsets.valid, dsrc1.valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsets = Datasets(range(5), [_Tfm(),noop], splits=[[1,2],[0,3,4]])\n",
    "test_eq(dsets.train,[(1,1),(2,2)])\n",
    "test_eq(dsets.valid,[(0,0),(6,3),(8,4)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = torch.arange(0,50)\n",
    "tds = Datasets(start, [A()])\n",
    "tdl = TfmdDL(tds, after_item=NegTfm(), bs=4)\n",
    "b = tdl.one_batch()\n",
    "test_eq(tdl.decode_batch(b), ((0,),(1,),(2,),(3,)))\n",
    "test_stdout(tdl.show_batch, \"0\\n1\\n2\\n3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only transform subset 1\n",
    "class _Tfm(Transform):\n",
    "    split_idx=1\n",
    "    def encodes(self, x): return x*2\n",
    "\n",
    "dsets = Datasets(range(8), [None], splits=[[1,2,5,7],[0,3,4,6]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only transform subset 1\n",
    "class _Tfm(Transform):\n",
    "    split_idx=1\n",
    "    def encodes(self, x): return x*2\n",
    "\n",
    "dsets = Datasets(range(8), [None], splits=[[1,2,5,7],[0,3,4,6]])\n",
    "dls = dsets.dataloaders(bs=4, after_batch=_Tfm(), shuffle_train=False, device=torch.device('cpu'))\n",
    "test_eq(dls.train, [(tensor([1,2,5, 7]),)])\n",
    "test_eq(dls.valid, [(tensor([0,6,8,12]),)])\n",
    "test_eq(dls.n_inp, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "items = [1,2,3,4]\n",
    "dsets = Datasets(items, [[neg_tfm,int2f_tfm]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Datasets.dataloaders\" class=\"doc_header\"><code>Datasets.dataloaders</code><a href=\"__main__.py#L15\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Datasets.dataloaders</code>(**`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`n`**=*`None`*, **`path`**=*`'.'`*, **`dl_type`**=*`None`*, **`dl_kwargs`**=*`None`*, **`device`**=*`None`*, **`shuffle`**=*`False`*, **`num_workers`**=*`None`*, **`verbose`**=*`False`*, **`do_setup`**=*`True`*, **`pin_memory`**=*`False`*, **`timeout`**=*`0`*, **`batch_size`**=*`None`*, **`drop_last`**=*`False`*, **`indexed`**=*`None`*, **`wif`**=*`None`*, **`before_iter`**=*`None`*, **`after_item`**=*`None`*, **`before_batch`**=*`None`*, **`after_batch`**=*`None`*, **`after_iter`**=*`None`*, **`create_batches`**=*`None`*, **`create_item`**=*`None`*, **`create_batch`**=*`None`*, **`retain`**=*`None`*, **`get_idxs`**=*`None`*, **`sample`**=*`None`*, **`shuffle_fn`**=*`None`*, **`do_batch`**=*`None`*)\n",
       "\n",
       "Get a [`DataLoaders`](/data.core.html#DataLoaders)"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide_input\n",
    "_dsrc = Datasets([1,2])\n",
    "show_doc(_dsrc.dataloaders, name=\"Datasets.dataloaders\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Datasets.decode\" class=\"doc_header\"><code>Datasets.decode</code><a href=\"__main__.py#L20\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Datasets.decode</code>(**`o`**, **`full`**=*`True`*)\n",
       "\n",
       "Compose `decode` of all `tuple_tfms` then all `tfms` on `i`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Datasets.decode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(*dsets[0], -1)\n",
    "test_eq(*dsets.decode((-1,)), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Datasets.show\" class=\"doc_header\"><code>Datasets.show</code><a href=\"__main__.py#L35\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Datasets.show</code>(**`o`**, **`ctx`**=*`None`*, **\\*\\*`kwargs`**)\n",
       "\n",
       "Show item `o` in `ctx`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Datasets.show)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_stdout(lambda:dsets.show(dsets[1]), '-2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Datasets.new_empty\" class=\"doc_header\"><code>Datasets.new_empty</code><a href=\"__main__.py#L24\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Datasets.new_empty</code>()\n",
       "\n",
       "Create a new empty version of the `self`, keeping only the transforms"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Datasets.new_empty)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "items = [1,2,3,4]\n",
    "nrm = Norm()\n",
    "dsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm]])\n",
    "empty = dsets.new_empty()\n",
    "test_eq(empty.items, [])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#test it works for dataframes too\n",
    "df = pd.DataFrame({'a':[1,2,3,4,5], 'b':[6,7,8,9,10]})\n",
    "dsets = Datasets(df, [[attrgetter('a')], [attrgetter('b')]])\n",
    "empty = dsets.new_empty()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add test set for inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only transform subset 1\n",
    "class _Tfm1(Transform):\n",
    "    split_idx=0\n",
    "    def encodes(self, x): return x*3\n",
    "\n",
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])\n",
    "test_eq(dsets.train, [(3,),(6,),(15,),(21,)])\n",
    "test_eq(dsets.valid, [(0,),(6,),(8,),(12,)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def test_set(dsets, test_items, rm_tfms=None, with_labels=False):\n",
    "    \"Create a test set from `test_items` using validation transforms of `dsets`\"\n",
    "    if isinstance(dsets, Datasets):\n",
    "        tls = dsets.tls if with_labels else dsets.tls[:dsets.n_inp]\n",
    "        test_tls = [tl._new(test_items, split_idx=1) for tl in tls]\n",
    "        if rm_tfms is None: rm_tfms = [tl.infer_idx(get_first(test_items)) for tl in test_tls]\n",
    "        else:               rm_tfms = tuplify(rm_tfms, match=test_tls)\n",
    "        for i,j in enumerate(rm_tfms): test_tls[i].tfms.fs = test_tls[i].tfms.fs[j:]\n",
    "        return Datasets(tls=test_tls)\n",
    "    elif isinstance(dsets, TfmdLists):\n",
    "        test_tl = dsets._new(test_items, split_idx=1)\n",
    "        if rm_tfms is None: rm_tfms = dsets.infer_idx(get_first(test_items))\n",
    "        test_tl.tfms.fs = test_tl.tfms.fs[rm_tfms:]\n",
    "        return test_tl\n",
    "    else: raise Exception(f\"This method requires using the fastai library to assemble your data. Expected a `Datasets` or a `TfmdLists` but got {dsets.__class__.__name__}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _Tfm1(Transform):\n",
    "    split_idx=0\n",
    "    def encodes(self, x): return x*3\n",
    "\n",
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])\n",
    "test_eq(dsets.train, [(3,),(6,),(15,),(21,)])\n",
    "test_eq(dsets.valid, [(0,),(6,),(8,),(12,)])\n",
    "\n",
    "#Tranform of the validation set are applied\n",
    "tst = test_set(dsets, [1,2,3])\n",
    "test_eq(tst, [(2,),(4,),(6,)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test with different types\n",
    "tfm = _Tfm1()\n",
    "tfm.split_idx,tfm.order = None,2\n",
    "dsets = Datasets(['dog', 'cat', 'cat', 'dog'], [[_Cat(),tfm]])\n",
    "\n",
    "#With strings\n",
    "test_eq(test_set(dsets, ['dog', 'cat', 'cat']), [(3,), (0,), (0,)])\n",
    "#With ints\n",
    "test_eq(test_set(dsets, [1,2]), [(3,), (6,)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test with various input lengths\n",
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])\n",
    "tst = test_set(dsets, [1,2,3])\n",
    "test_eq(tst, [(2,2),(4,4),(6,6)])\n",
    "\n",
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]], n_inp=1)\n",
    "tst = test_set(dsets, [1,2,3])\n",
    "test_eq(tst, [(2,),(4,),(6,)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#Test with rm_tfms\n",
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]])\n",
    "tst = test_set(dsets, [1,2,3])\n",
    "test_eq(tst, [(4,),(8,),(12,)])\n",
    "\n",
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]])\n",
    "tst = test_set(dsets, [1,2,3], rm_tfms=1)\n",
    "test_eq(tst, [(2,),(4,),(6,)])\n",
    "\n",
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm()], [_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]], n_inp=2)\n",
    "tst = test_set(dsets, [1,2,3], rm_tfms=(1,0))\n",
    "test_eq(tst, [(2,4),(4,8),(6,12)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(TfmdDL.__init__)\n",
    "@patch\n",
    "def test_dl(self:DataLoaders, test_items, rm_type_tfms=None, with_labels=False, **kwargs):\n",
    "    \"Create a test dataloader from `test_items` using validation transforms of `dls`\"\n",
    "    test_ds = test_set(self.valid_ds, test_items, rm_tfms=rm_type_tfms, with_labels=with_labels\n",
    "                      ) if isinstance(self.valid_ds, (Datasets, TfmdLists)) else test_items\n",
    "    return self.valid.new(test_ds, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])\n",
    "dls = dsets.dataloaders(bs=4, device=torch.device('cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])\n",
    "dls = dsets.dataloaders(bs=4, device=torch.device('cpu'))\n",
    "tst_dl = dls.test_dl([2,3,4,5])\n",
    "test_eq(tst_dl._n_inp, 1)\n",
    "test_eq(list(tst_dl), [(tensor([ 4,  6,  8, 10]),)])\n",
    "#Test you can change transforms\n",
    "tst_dl = dls.test_dl([2,3,4,5], after_item=add1)\n",
    "test_eq(list(tst_dl), [(tensor([ 5,  7,  9, 11]),)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "#Run nbdev's `notebook2script` export (the saved output lists each notebook converted)\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
